import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from scipy.stats import gaussian_kde
import os
from svglib.svglib import svg2rlg
from reportlab.graphics import renderPDF

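# PlotsLowRank.py: reads a CSV of theoretical (y_theo) and empirical (y_emp)
# samples and plots their kernel density estimates for each `param` value.
# Figures are written as SVG, converted to PDF, and the SVG is then deleted.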

# Shared Plotly figure settings applied via fig.update_layout in the plots below:
# figure size, outer margins, global font, and axis grid/zero-line styling.
width = 700
height = 400
margin = {'l': 20, 'r': 20, 't': 20, 'b': 20}
font = {'family': 'Times New Roman', 'size': 24, 'color': 'black'}
grid = {
    'showgrid': True,
    'gridcolor': 'lightgray',
    'zeroline': True,
    'zerolinecolor': 'lightgray',
}

def plot_lowrank(csv_file: str, output_basename: str,
                 x_range=(-10, 10), n_points: int = 200) -> None:
    """Plot theoretical vs. empirical KDEs per ``param`` value and save a PDF.

    For each distinct value of the ``param`` column (ascending order), two
    Gaussian KDE curves are drawn on a shared x grid. Both use the same colour,
    sampled from a blue -> purple -> red gradient:

    * solid:  KDE of ``y_theo``, labelled ``∞-width, H=<param>``;
    * dotted: KDE of ``y_emp``,  labelled ``n=<param*64>, H=<param>``.

    The figure is written to ``<output_basename>.svg``, converted to
    ``<output_basename>.pdf`` via svglib/reportlab, and the SVG is then
    deleted. The deletion also happens when the conversion fails.

    Args:
        csv_file: Path to a CSV containing at least the columns ``param``,
            ``y_theo`` and ``y_emp``.
        output_basename: Output path without file extension.
        x_range: ``(low, high)`` bounds of the grid the KDEs are evaluated on.
        n_points: Number of evenly spaced grid points within ``x_range``.

    Raises:
        RuntimeError: If svglib cannot parse the intermediate SVG.
        Errors from ``scipy.stats.gaussian_kde`` propagate unchanged when a
        group cannot be fitted (e.g. it contains a single sample).
    """
    df = pd.read_csv(csv_file)
    h_values = sorted(df['param'].unique())

    # Sample the gradient at explicit positions in [0, 1]. Passing the bare
    # count to sample_colorscale makes plotly compute idx / (count - 1),
    # which raises ZeroDivisionError when the CSV holds a single param value.
    n_groups = len(h_values)
    if n_groups > 1:
        positions = [i / (n_groups - 1) for i in range(n_groups)]
    else:
        positions = [0.0] * n_groups
    palette = px.colors.sample_colorscale(
        [[0.0, '#0000FF'], [0.5, '#800080'], [1.0, '#FF0000']], positions
    )

    xs = np.linspace(x_range[0], x_range[1], n_points)
    fig = go.Figure()
    for color, H_val in zip(palette, h_values):
        sub = df[df['param'] == H_val]

        # Theoretical ("∞-width") density: solid line.
        kde_theo = gaussian_kde(sub['y_theo'].values)
        fig.add_trace(go.Scatter(
            x=xs, y=kde_theo(xs), mode='lines',
            name=f'∞-width, H={H_val}',
            line=dict(color=color, dash='solid')
        ))

        # Empirical density: dotted line in the same colour. legendrank=1 is
        # below plotly's default rank, so these entries are listed first.
        # NOTE(review): the label assumes the empirical width is n = 64 * H;
        # confirm against the data generator.
        kde_emp = gaussian_kde(sub['y_emp'].values)
        fig.add_trace(go.Scatter(
            x=xs, y=kde_emp(xs), mode='lines', name=f'n={H_val*64}, H={H_val}',
            line=dict(color=color, dash='dot'),
            legendrank=1
        ))

    fig.update_layout(
        font=font,
        width=width, height=height,
        margin=margin,
        plot_bgcolor='white', paper_bgcolor='white',
        xaxis=grid, yaxis=grid
    )

    svg_file = f"{output_basename}.svg"
    pdf_file = f"{output_basename}.pdf"
    fig.write_image(svg_file)
    try:
        drawing = svg2rlg(svg_file)
        # svg2rlg returns None instead of raising when it cannot parse the file.
        if drawing is None:
            raise RuntimeError(f"svglib could not parse {svg_file}")
        renderPDF.drawToFile(drawing, pdf_file)
    finally:
        # The SVG is only an intermediate artefact; never leave it behind.
        os.remove(svg_file)


if __name__ == '__main__':
    # Script entry point: render the default experiment CSV to lowrank.pdf.
    source_csv = 'data_vary_n_and_H.csv'
    plot_lowrank(source_csv, output_basename='lowrank')
